{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ResNet-50 Training on (Fake) ImageNet with WebDataset\n",
    "\n",
    "This notebook illustrates how to use WebDataset with PyTorch training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-12-14T05:38:33.429679Z",
     "iopub.status.busy": "2023-12-14T05:38:33.429158Z",
     "iopub.status.idle": "2023-12-14T05:38:36.121794Z",
     "shell.execute_reply": "2023-12-14T05:38:36.120859Z"
    }
   },
   "outputs": [],
   "source": [
    "# imports\n",
    "\n",
    "%matplotlib inline\n",
    "\n",
    "from functools import partial\n",
    "from pprint import pprint\n",
    "import random\n",
    "from collections import deque\n",
    "import numpy as np\n",
    "from matplotlib import pyplot as plt\n",
    "import sys\n",
    "import os\n",
    "\n",
    "import torch\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "from torchvision.models import resnet50\n",
    "from torch.utils.data import DataLoader\n",
    "from torch import nn, optim\n",
    "\n",
    "# helpers\n",
    "\n",
    "import time\n",
    "\n",
    "\n",
    "def enumerate_report(seq, delta, growth=1.0):\n",
    "    \"\"\"Enumerate `seq`, flagging the items at which a progress report is due.\n",
    "\n",
    "    Yields `(index, item, report)` triples. `report` is True for the first item\n",
    "    (for any finite `delta`) and afterwards whenever more than `delta` seconds\n",
    "    have elapsed since the previously flagged item. After every item, `delta`\n",
    "    is multiplied by `growth`, so `growth > 1` makes reports progressively rarer.\n",
    "    \"\"\"\n",
    "    # -inf guarantees the very first item is flagged.\n",
    "    last = float(\"-inf\")\n",
    "    for count, item in enumerate(seq):\n",
    "        # monotonic() measures intervals safely; time() can jump with clock changes.\n",
    "        now = time.monotonic()\n",
    "        if now - last > delta:\n",
    "            last = now\n",
    "            yield count, item, True\n",
    "        else:\n",
    "            yield count, item, False\n",
    "        delta *= growth"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-12-14T05:38:36.126732Z",
     "iopub.status.busy": "2023-12-14T05:38:36.126472Z",
     "iopub.status.idle": "2023-12-14T05:38:36.138014Z",
     "shell.execute_reply": "2023-12-14T05:38:36.137312Z"
    }
   },
   "outputs": [],
   "source": [
    "# We usually abbreviate webdataset as wds\n",
    "import webdataset as wds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "# parameters\n",
    "# (this cell carries the \"parameters\" tag, the convention papermill uses to inject values)\n",
    "epochs = 1\n",
    "max_steps = 10**12  # upper bound checked by the training loop; effectively unlimited\n",
    "batchsize = 32\n",
    "bucket = \"https://storage.googleapis.com/webdataset/fake-imagenet\"\n",
    "# WebDataset expands the {a..b} brace range into one URL per shard.\n",
    "training_urls = bucket + \"/imagenet-train-{000000..001281}.tar\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data Loader Construction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-12-14T05:38:36.150848Z",
     "iopub.status.busy": "2023-12-14T05:38:36.150313Z",
     "iopub.status.idle": "2023-12-14T05:38:36.278365Z",
     "shell.execute_reply": "2023-12-14T05:38:36.276943Z"
    }
   },
   "outputs": [],
   "source": [
    "# WebDataset is designed to work without any local storage. Use caching\n",
    "# only if you are on a desktop with slow networking.\n",
    "\n",
    "if \"google.colab\" in sys.modules:\n",
    "    cache_dir = None\n",
    "    print(\"running on colab, streaming data directly from storage\")\n",
    "else:\n",
    "    cache_dir = \"./_cache\"\n",
    "    # Pure-Python equivalent of `mkdir -p`: portable, and needs no IPython shell magic.\n",
    "    os.makedirs(cache_dir, exist_ok=True)\n",
    "    print(f\"not running in colab, caching data locally in {cache_dir}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-12-14T05:38:36.356785Z",
     "iopub.status.busy": "2023-12-14T05:38:36.356004Z",
     "iopub.status.idle": "2023-12-14T05:38:36.368348Z",
     "shell.execute_reply": "2023-12-14T05:38:36.366853Z"
    }
   },
   "outputs": [],
   "source": [
    "# Standard TorchVision training-time augmentation, followed by normalization\n",
    "# with the usual ImageNet per-channel mean/std.\n",
    "\n",
    "transform_train = transforms.Compose([\n",
    "    # crop a random area / aspect-ratio patch and resize it to 224x224\n",
    "    transforms.RandomResizedCrop(224),\n",
    "    transforms.RandomHorizontalFlip(),\n",
    "    # PIL image -> float tensor scaled to [0, 1]\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
    "])\n",
    "\n",
    "\n",
    "def make_sample(sample, val=False):\n",
    "    \"\"\"Convert one decoded WebDataset sample dict into an `(image, label)` pair.\n",
    "\n",
    "    The \"jpg\" entry is passed through `transform_train`; the \"cls\" entry is\n",
    "    returned unchanged. Only the training path exists, so `val=True` fails the assert.\n",
    "    \"\"\"\n",
    "    assert not val, \"only implemented training dataset for this notebook\"\n",
    "    image, label = sample[\"jpg\"], sample[\"cls\"]\n",
    "    return transform_train(image), label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-12-14T05:38:36.374043Z",
     "iopub.status.busy": "2023-12-14T05:38:36.372928Z",
     "iopub.status.idle": "2023-12-14T05:38:36.396581Z",
     "shell.execute_reply": "2023-12-14T05:38:36.395381Z"
    }
   },
   "outputs": [],
   "source": [
    "# Create the dataset with shard shuffling, sample shuffling, and decoding.\n",
    "trainset = wds.WebDataset(\n",
    "    training_urls, resampled=True, cache_dir=cache_dir, shardshuffle=True\n",
    ")\n",
    "trainset = trainset.shuffle(1000).decode(\"pil\").map(make_sample)\n",
    "\n",
    "# Since this is an IterableDataset, PyTorch requires that we batch in the dataset.\n",
    "# These 64-sample chunks are unbatched again below, so this size only controls\n",
    "# how samples are shipped from the worker processes.\n",
    "trainset = trainset.batched(64)\n",
    "# WebLoader is PyTorch DataLoader with some convenience methods.\n",
    "trainloader = wds.WebLoader(trainset, batch_size=None, num_workers=4)\n",
    "\n",
    "# Unbatch, shuffle across workers, then rebatch to the configured training batch size.\n",
    "trainloader = trainloader.unbatched().shuffle(1000).batched(batchsize)\n",
    "\n",
    "# Since we are using resampling, the dataset is infinite; set an artificial epoch size.\n",
    "# with_epoch counts batches here, so divide a fixed sample budget (1282 shards x 100,\n",
    "# presumably samples per shard -- TODO confirm) by the batch size.\n",
    "trainloader = trainloader.with_epoch(1282 * 100 // batchsize)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-12-14T05:38:36.402668Z",
     "iopub.status.busy": "2023-12-14T05:38:36.401448Z",
     "iopub.status.idle": "2023-12-14T05:38:37.694483Z",
     "shell.execute_reply": "2023-12-14T05:38:37.693554Z"
    }
   },
   "outputs": [],
   "source": [
    "# Smoke test it: fetch one batch with webdataset's GOPEN_VERBOSE logging turned on.\n",
    "# The previous value of the variable is restored even if fetching the batch fails.\n",
    "\n",
    "previous_gopen_verbose = os.environ.get(\"GOPEN_VERBOSE\")\n",
    "os.environ[\"GOPEN_VERBOSE\"] = \"1\"\n",
    "try:\n",
    "    images, classes = next(iter(trainloader))\n",
    "finally:\n",
    "    if previous_gopen_verbose is None:\n",
    "        os.environ.pop(\"GOPEN_VERBOSE\", None)\n",
    "    else:\n",
    "        os.environ[\"GOPEN_VERBOSE\"] = previous_gopen_verbose\n",
    "print(images.shape, classes.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# PyTorch Training\n",
    "\n",
    "This is a typical PyTorch training pipeline."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-12-14T05:38:37.699369Z",
     "iopub.status.busy": "2023-12-14T05:38:37.698777Z",
     "iopub.status.idle": "2023-12-14T05:38:39.799167Z",
     "shell.execute_reply": "2023-12-14T05:38:39.798175Z"
    }
   },
   "outputs": [],
   "source": [
    "# The usual PyTorch model definition. We use a randomly initialized ResNet50 model\n",
    "# (no pretrained weights). Calling resnet50() with no arguments means exactly that\n",
    "# on both old and new torchvision, without the deprecated `pretrained=` keyword.\n",
    "\n",
    "model = resnet50()\n",
    "\n",
    "# Move the model to the GPU if available. Do this before building the optimizer,\n",
    "# as recommended by the torch.optim documentation.\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "model = model.to(device)\n",
    "\n",
    "# Define the loss function and optimizer\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-12-14T05:38:39.804002Z",
     "iopub.status.busy": "2023-12-14T05:38:39.803809Z",
     "iopub.status.idle": "2023-12-14T06:43:38.731276Z",
     "shell.execute_reply": "2023-12-14T06:43:38.730369Z"
    }
   },
   "outputs": [],
   "source": [
    "# Sliding windows over the most recent 100 batches, used for the progress printout.\n",
    "losses, accuracies = deque(maxlen=100), deque(maxlen=100)\n",
    "\n",
    "# NOTE: despite its name, `steps` accumulates the number of training *samples*\n",
    "# seen, and that count is what gets compared against `max_steps` below.\n",
    "steps = 0\n",
    "\n",
    "# Train the model (explicitly in training mode so BatchNorm layers update).\n",
    "model.train()\n",
    "for epoch in range(epochs):\n",
    "    for i, data, verbose in enumerate_report(trainloader, 5):\n",
    "        # get the inputs; data holds the (inputs, labels) batch\n",
    "        inputs, labels = data[0].to(device), data[1].to(device)\n",
    "\n",
    "        # zero the parameter gradients\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        # forward + backward + optimize\n",
    "        outputs = model(inputs)\n",
    "        loss = criterion(outputs, labels)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # Batch accuracy computed on-device; only the final scalar is copied to the host.\n",
    "        correct = (outputs.detach().argmax(dim=1) == labels).sum().item()\n",
    "        accuracy = correct / len(labels)\n",
    "\n",
    "        losses.append(loss.item())\n",
    "        accuracies.append(accuracy)\n",
    "        steps += len(inputs)\n",
    "\n",
    "        # enumerate_report flags a batch roughly every 5 s; wait for a little history first.\n",
    "        if verbose and len(losses) > 5:\n",
    "            print(\n",
    "                \"[%d, %5d] loss: %.5f correct: %.5f\"\n",
    "                % (epoch + 1, i + 1, np.mean(losses), np.mean(accuracies))\n",
    "            )\n",
    "\n",
    "        if steps > max_steps:\n",
    "            break\n",
    "\n",
    "    if steps > max_steps:\n",
    "        break\n",
    "\n",
    "print(\"Finished Training\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
